# Copyright 2024 Ant Group Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

load("@rules_cuda//cuda:defs.bzl", "cuda_library")
load("@yacl//bazel:yacl.bzl", "yacl_cc_library", "yacl_cc_test")

# Targets in this package are publicly visible unless they set their own
# `visibility`.
package(default_visibility = ["//visibility:public"])

# No `tests` attribute is given, so Bazel's default applies: the suite collects
# every test rule in this BUILD file that is not tagged "manual".
test_suite(
    name = "paillier_gpu_tests",
)

# Entry-point target of this package: exports paillier.h.
#
# When rules_cuda is enabled it defines ENABLE_GPAILLIER=true and pulls in
# :paillier_gpu_impl; otherwise it defines ENABLE_GPAILLIER=false and carries
# no dependencies, so the header can still be included in non-CUDA builds.
yacl_cc_library(
    name = "paillier_gpu",
    hdrs = ["paillier.h"],
    defines = select({
        "@rules_cuda//cuda:is_enabled": ["ENABLE_GPAILLIER=true"],
        "//conditions:default": ["ENABLE_GPAILLIER=false"],
    }),
    tags = ["manual"],
    deps = select({
        "@rules_cuda//cuda:is_enabled": [":paillier_gpu_impl"],
        "//conditions:default": [],
    }),
)

# Source-less aggregation target: bundles the key generator, encryptor,
# decryptor and evaluator libraries behind a single label.
# Declared compatible only with x86_64 Linux.
yacl_cc_library(
    name = "paillier_gpu_impl",
    tags = ["manual"],
    target_compatible_with = [
        "@platforms//cpu:x86_64",
        "@platforms//os:linux",
    ],
    deps = [
        ":decryptor",
        ":encryptor",
        ":evaluator",
        ":key_generator",
    ],
)

# CUDA library compiled from gpulib/paillier.cu; depends on the CGBN target.
cuda_library(
    name = "gpupaillier",
    srcs = [
        "gpulib/paillier.cu",
    ],
    hdrs = [
        "gpulib/error.h",
        "gpulib/gpu_paillier.h",
        "gpulib/gpupaillier.h",
    ],
    # Compile-only flags. Linker flags such as "-lcuda" / "-lcudart" do not
    # belong here: copts are applied to compile actions, where "-l" is unused.
    # NOTE(review): if explicit linking against libcuda/libcudart is ever
    # required, express it via `linkopts` or a dependency instead of copts --
    # confirm against the rules_cuda version in use.
    copts = [
        "-std=c++17",
    ],
    tags = [
        "manual",
    ],
    deps = [
        "@com_github_nvlabs_cgbn//:cgbn",
    ],
)

# Header-only target exporting plaintext.h.
# Tagged "manual" so wildcard patterns such as //... do not build it.
yacl_cc_library(
    name = "plaintext",
    hdrs = ["plaintext.h"],
    tags = ["manual"],
    deps = [
        "//heu/library/algorithms/util",
        "@com_github_msgpack_msgpack//:msgpack",
    ],
)

# Library target for secret_key.cc / secret_key.h.
# Depends on :plaintext, the shared algorithm utilities and msgpack.
# Tagged "manual" so wildcard patterns such as //... do not build it.
yacl_cc_library(
    name = "secret_key",
    srcs = ["secret_key.cc"],
    hdrs = ["secret_key.h"],
    tags = ["manual"],
    deps = [
        ":plaintext",
        "//heu/library/algorithms/util",
        "@com_github_msgpack_msgpack//:msgpack",
    ],
)

# Library target for public_key.cc / public_key.h.
# Depends on :plaintext, the shared algorithm utilities and msgpack.
# Tagged "manual" so wildcard patterns such as //... do not build it.
yacl_cc_library(
    name = "public_key",
    srcs = ["public_key.cc"],
    hdrs = ["public_key.h"],
    tags = ["manual"],
    deps = [
        ":plaintext",
        "//heu/library/algorithms/util",
        "@com_github_msgpack_msgpack//:msgpack",
    ],
)

# Library target for ciphertext.cc / ciphertext.h.
# Depends on the CUDA library target (:gpupaillier) in addition to
# :plaintext, the shared algorithm utilities and msgpack.
# Tagged "manual" so wildcard patterns such as //... do not build it.
yacl_cc_library(
    name = "ciphertext",
    srcs = ["ciphertext.cc"],
    hdrs = ["ciphertext.h"],
    tags = ["manual"],
    deps = [
        ":gpupaillier",
        ":plaintext",
        "//heu/library/algorithms/util",
        "@com_github_msgpack_msgpack//:msgpack",
    ],
)

# Library target for key_generator.cc / key_generator.h.
# Depends only on the two key targets, :public_key and :secret_key.
# Tagged "manual" so wildcard patterns such as //... do not build it.
yacl_cc_library(
    name = "key_generator",
    srcs = ["key_generator.cc"],
    hdrs = ["key_generator.h"],
    tags = ["manual"],
    deps = [
        ":public_key",
        ":secret_key",
    ],
)

# Library target for encryptor.cc / encryptor.h.
# Depends on the CUDA library target (:gpupaillier), the ciphertext and key
# targets, and the shared algorithm utilities.
# Tagged "manual" so wildcard patterns such as //... do not build it.
yacl_cc_library(
    name = "encryptor",
    srcs = ["encryptor.cc"],
    hdrs = ["encryptor.h"],
    tags = ["manual"],
    deps = [
        ":ciphertext",
        ":gpupaillier",
        ":public_key",
        ":secret_key",
        "//heu/library/algorithms/util",
    ],
)

# Library target for decryptor.cc / decryptor.h.
# Depends on the CUDA library target (:gpupaillier), the ciphertext and key
# targets, and the shared algorithm utilities.
# Tagged "manual" so wildcard patterns such as //... do not build it.
yacl_cc_library(
    name = "decryptor",
    srcs = ["decryptor.cc"],
    hdrs = ["decryptor.h"],
    tags = ["manual"],
    deps = [
        ":ciphertext",
        ":gpupaillier",
        ":public_key",
        ":secret_key",
        "//heu/library/algorithms/util",
    ],
)

# Library target for evaluator.cc / evaluator.h.
# Depends on :encryptor and :ciphertext, the CUDA library target
# (:gpupaillier), :public_key and the shared algorithm utilities.
# Tagged "manual" so wildcard patterns such as //... do not build it.
yacl_cc_library(
    name = "evaluator",
    srcs = ["evaluator.cc"],
    hdrs = ["evaluator.h"],
    tags = ["manual"],
    deps = [
        ":ciphertext",
        ":encryptor",
        ":gpupaillier",
        ":public_key",
        "//heu/library/algorithms/util",
    ],
)

# Test built from paillier_gpu_test.cc against :paillier_gpu.
# With rules_cuda enabled the test is compatible with x86_64 Linux only;
# in every other configuration it is marked incompatible.
yacl_cc_test(
    name = "paillier_gpu_test",
    srcs = ["paillier_gpu_test.cc"],
    target_compatible_with = select({
        "@rules_cuda//cuda:is_enabled": [
            "@platforms//cpu:x86_64",
            "@platforms//os:linux",
        ],
        "//conditions:default": ["@platforms//:incompatible"],
    }),
    deps = [":paillier_gpu"],
)

# Test built from paillier_gpu_test_con.cc against :paillier_gpu.
# With rules_cuda enabled the test is compatible with x86_64 Linux only;
# in every other configuration it is marked incompatible.
yacl_cc_test(
    name = "paillier_gpu_test_con",
    srcs = ["paillier_gpu_test_con.cc"],
    target_compatible_with = select({
        "@rules_cuda//cuda:is_enabled": [
            "@platforms//cpu:x86_64",
            "@platforms//os:linux",
        ],
        "//conditions:default": ["@platforms//:incompatible"],
    }),
    deps = [":paillier_gpu"],
)
